{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training GNNs on Large Graphs\n",
    "\n",
    "We have seen the example of training GNNs on the entire graph.  However, usually our graph is very big: it could contain millions or billions of nodes and edges.  The storage required for the graph would be many times bigger if we consider node and edge features.  If we want to utilize GPUs for faster computation, we would notice that full graph training is often impossible on GPUs because our graph and features cannot fit into a single GPU.  Not to mention that the node representation of intermediate layers are also stored for the sake of backpropagation.\n",
    "\n",
    "To get over this limit, we employ two methodologies:\n",
    "\n",
    "1. Stochastic training on graphs.\n",
    "2. Neighbor sampling on graphs."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stochastic Training on Graphs\n",
    "\n",
    "If you are familiar with deep learning for images/texts/etc., you should know stochastic gradient descent (SGD) very well.  In SGD, you sample a minibatch of examples, compute the loss on those examples only, find the gradients, and update the model parameters.\n",
    "\n",
    "Stochastic training on graphs resembles SGD on image/text datasets in the sense that one also samples a minibatch of nodes (or pair/tuple of nodes, depending on the task) and compute the loss on those nodes only.  The difference is that the output representation of a small set of nodes may depend on the input features of a substantially larger set of nodes.\n",
    "\n",
    "### GraphSAGE Recap\n",
    "\n",
    "In previous session, we have discussed GraphSAGE model.\n",
    "\n",
    "The output representation $\\boldsymbol{y}_u$ of node $u$ from a GraphSAGE layer is simply computed by:\n",
    "\n",
    "* Aggregating the input features of all neighbors of $u$ by for instance averaging.\n",
    "* Concatenating the aggregation with the node $u$'s representation itself.\n",
    "* Passing the concatenation to an MLP.\n",
    "\n",
    "And we have defined the following `GraphSAGEModel`. It leveraged DGL's built-in class `SAGEConv` and can compute multi-layer outputs on a whole graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dgl\n",
    "import dgl.function as fn\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dgl.nn.pytorch import conv as dgl_conv\n",
    "\n",
    "class GraphSAGEModel(nn.Module):\n",
    "    def __init__(self,\n",
    "                 in_feats,\n",
    "                 n_hidden,\n",
    "                 out_dim,\n",
    "                 n_layers,\n",
    "                 activation,\n",
    "                 dropout,\n",
    "                 aggregator_type):\n",
    "        super(GraphSAGEModel, self).__init__()\n",
    "        self.layers = nn.ModuleList()\n",
    "\n",
    "        # input layer\n",
    "        self.layers.append(dgl_conv.SAGEConv(in_feats, n_hidden, aggregator_type,\n",
    "                                         feat_drop=dropout, activation=activation))\n",
    "        # hidden layers\n",
    "        for i in range(n_layers - 1):\n",
    "            self.layers.append(dgl_conv.SAGEConv(n_hidden, n_hidden, aggregator_type,\n",
    "                                             feat_drop=dropout, activation=activation))\n",
    "        # output layer\n",
    "        self.layers.append(dgl_conv.SAGEConv(n_hidden, out_dim, aggregator_type,\n",
    "                                         feat_drop=dropout, activation=None))\n",
    "\n",
    "    def forward(self, g, features):\n",
    "        h = features\n",
    "        for layer in self.layers:\n",
    "            h = layer(g, h)\n",
    "        return h"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Batching on a Graph\n",
    "\n",
    "For stochastic training, we want to split training data into small batches and only put necessary information into GPU for each step of training. In case of node classification, we want to split the labeld nodes into batches. Let take a deep look of what information is necessary for a batch of nodes.\n",
    "\n",
    "For instance, consider the following graph:\n",
    "\n",
    "![Graph](assets/graph.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A small graph\n",
    "\n",
    "import networkx as nx\n",
    "\n",
    "example_graph = nx.Graph(\n",
    "    [(0, 2), (0, 4), (0, 6), (0, 7), (0, 8), (0, 9), (0, 10),\n",
    "     (1, 2), (1, 3), (1, 5), (2, 3), (2, 4), (2, 6), (3, 5),\n",
    "     (3, 8), (4, 7), (8, 9), (8, 11), (9, 10), (9, 11)])\n",
    "example_graph = dgl.graph(example_graph)\n",
    "# We also assign features for each node\n",
    "INPUT_FEATURES = 5\n",
    "OUTPUT_FEATURES = 6\n",
    "example_graph.ndata['features'] = torch.randn(12, INPUT_FEATURES)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we wish to compute the output representation of node 4 and 6 with a GraphSAGE layer, we actually need the input feature of node 4 and 6 themselves, as well as their neighbors (node 7, 0 and 2):\n",
    "\n",
    "![Graph](assets/graph_1layer_46.png)\n",
    "\n",
    "We can see that node 7, 0, and 2 will contribute to representation of node 4, while 0 and 2 will contribute to node 6.\n",
    "\n",
    "### Finding Neighbors of Nodes\n",
    "\n",
    "DGL provides an API: `dgl.in_subgraph`, that takes in a set of nodes and returns a graph consisting of all edges going to one of the given nodes.  Such a graph can exactly describe the computation dependency above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(tensor([0, 2, 7, 0, 2]), tensor([4, 4, 4, 6, 6]))\n"
     ]
    }
   ],
   "source": [
    "sampled_node_batch = torch.LongTensor([4, 6])   # These are the nodes whose outputs are to be computed\n",
    "sampled_graph = dgl.in_subgraph(example_graph, sampled_node_batch)\n",
    "print(sampled_graph.all_edges())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The result above reads that node 0, 2 and 7 connects to node 4, while node 0 and 2 connects to node 6.\n",
    "\n",
    "The returned subgraph will look like this:\n",
    "\n",
    "![](assets/in_subgraph_1.png)\n",
    "\n",
    "Note that it is directed.\n",
    "\n",
    "### Message Passing on One Layer\n",
    "\n",
    "\n",
    "Nodes in such a sub graph can have two roles: \n",
    "* The output nodes, which only contain the nodes whose outputs are to be computed.\n",
    "* The input nodes, which contain the neighbors of those nodes.  For models such as GraphSAGE, the output nodes themselves are also input nodes.\n",
    "\n",
    "In later propagation, logic of information flow is quite different on nodes with different roles.\n",
    "DGL further provides a bipartite structure *block* to better reflect this feature. A sub graph can be easily converted to a block with function `dgl.to_block`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Node ID of input nodes in original graph: tensor([4, 6, 0, 2, 7])\n",
      "Node ID of output nodes in original graph: tensor([4, 6])\n",
      "Edge connections: tensor([0, 2, 7, 0, 2]) tensor([4, 4, 4, 6, 6])\n"
     ]
    }
   ],
   "source": [
    "sampled_block = dgl.to_block(sampled_graph, sampled_node_batch)\n",
    "\n",
    "def print_block_info(sampled_block):\n",
    "    sampled_input_nodes = sampled_block.srcdata[dgl.NID]\n",
    "    print('Node ID of input nodes in original graph:', sampled_input_nodes)\n",
    "\n",
    "    sampled_output_nodes = sampled_block.dstdata[dgl.NID]\n",
    "    print('Node ID of output nodes in original graph:', sampled_output_nodes)\n",
    "\n",
    "    sampled_block_edges_src, sampled_block_edges_dst = sampled_block.all_edges()\n",
    "    # We need to map the src and dst node IDs in the blocks to the node IDs in the original graph.\n",
    "    sampled_block_edges_src_mapped = sampled_input_nodes[sampled_block_edges_src]\n",
    "    sampled_block_edges_dst_mapped = sampled_output_nodes[sampled_block_edges_dst]\n",
    "    print('Edge connections:', sampled_block_edges_src_mapped, sampled_block_edges_dst_mapped)\n",
    "    \n",
    "print_block_info(sampled_block)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that the input nodes also include node 4 and 6, which are the output nodes themselves. And the edge connections are preserved (i.e. they map to the same ones in `sampled_graph`).\n",
    "\n",
    "Essentially, the block will have the following structure.  The input nodes reside on the left and the output nodes reside on the right.  Note that node 4 and 6 appear in both input nodes and output nodes.\n",
    "\n",
    "![](assets/block1.png)\n",
    "\n",
    "#### GraphSAGE Layer on Blocks\n",
    "\n",
    "The sampled block is ensantially a bipartite graph. We have seen in previous example that DGL's built-in class `SAGEConv` works perfectly on whole graph. Does it also function properly on a *Block*? The answer is yes. Acutally all of DGL's neural network layers support working on both homogeneous graphs and bipartite graphs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.0000, 0.0000, 1.2657, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 1.5647, 2.1398, 0.1513, 0.0000, 0.2595]],\n",
      "       grad_fn=<ReluBackward0>)\n"
     ]
    }
   ],
   "source": [
    "import dgl.nn as dglnn\n",
    "sageconv_module = dglnn.SAGEConv(INPUT_FEATURES, OUTPUT_FEATURES, 'mean', activation=F.relu)\n",
    "\n",
    "sampled_block_src_features = example_graph.ndata['features'][sampled_block.srcdata[dgl.NID]]\n",
    "sampled_block_dst_features = example_graph.ndata['features'][sampled_block.dstdata[dgl.NID]]\n",
    "\n",
    "output_of_sampled_node_batch = sageconv_module(\n",
    "    sampled_block, (sampled_block_src_features, sampled_block_dst_features))\n",
    "print(output_of_sampled_node_batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Writing Your Own Layers on Blocks\n",
    "\n",
    "The code of writing your own layer that works on blocks is very similar to the one you write for homogeneous graphs, except for a few differences described in the comments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SAGEConvBipartite(nn.Module):\n",
    "    def __init__(self, in_feats, out_feats):\n",
    "        super().__init__()\n",
    "        self.W = nn.Linear(2 * in_feats, out_feats)\n",
    "        \n",
    "    def forward(self, g, x):\n",
    "        # Because g is now a bipartite graph, we now need to input tensors, one on the input\n",
    "        # side and another on the output side.\n",
    "        x_src, x_dst = x\n",
    "        \n",
    "        with g.local_scope():\n",
    "            # Aggregate input features of neighbors\n",
    "            g.srcdata['x'] = x_src                                      # ndata here is changed to srcdata\n",
    "                                                                        # x is also changed to x_src (input side)\n",
    "            g.update_all(fn.copy_u('x', 'm'), fn.mean('m', 'x_neigh'))\n",
    "            # Concatenate aggregation with the node representation itself\n",
    "            x_neigh = g.dstdata['x_neigh']                              # ndata here is changed to dstdata\n",
    "            x_concat = torch.cat([x_dst, x_neigh], -1)                  # x is changed to x_dst (output side)\n",
    "            # Pass the concatenation to an MLP\n",
    "            y = F.relu(self.W(x_concat))\n",
    "            return y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multiple Layers\n",
    "\n",
    "Now we wish to compute the output of node 4 and 6 from a 2-layer GraphSAGE.  This requires the input features of not only the nodes themselves and their neighbors, but also the neighbors of these neighbors.\n",
    "\n",
    "![](assets/graph_2layer_46.png)\n",
    "\n",
    "To compute the 2-layer output of node 4 and 6, we first need to obtain the 1-layer output of node 4 and 6, as well as the neighbors (node 7, 0, and 2).  To obtain the 1-layer output of all these nodes, we again need the input feature of these nodes (node 4, 6, 7, 0, 2) as well as *their* neighbors (node 10, 9, 8, 1, and 3).\n",
    "\n",
    "Recall that to compute the 2-layer representation of node 4 and 6, the computation dependency on 2nd layer include node 4 and 6 as well as their neighbors:\n",
    "\n",
    "![](assets/in_subgraph_1.png)\n",
    "\n",
    "To compute the 1st layer representation of those nodes we further need more dependencies:\n",
    "\n",
    "![](assets/in_subgraph_2.png)\n",
    "\n",
    "If we draw the propagation flow as bipartite graphs, they will look like this:\n",
    "\n",
    "![](assets/block2.png)\n",
    "\n",
    "We can see that the generation of computation dependency for multi-layer GNNs is a bottom-up process: we start from the output layer, and grows the node set towards the input layer.\n",
    "\n",
    "The following code directly returns the list of blocks as the computation dependency generation for multi-layer GNNs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FullNeighborBlockSampler(object):\n",
    "    def __init__(self, g, num_layers):\n",
    "        self.g = g\n",
    "        self.num_layers = num_layers\n",
    "        \n",
    "    def sample(self, seeds):\n",
    "        blocks = []\n",
    "        for i in range(self.num_layers):\n",
    "            sampled_graph = dgl.in_subgraph(self.g, seeds)\n",
    "            sampled_block = dgl.to_block(sampled_graph, seeds)\n",
    "            seeds = sampled_block.srcdata[dgl.NID]\n",
    "            # Because the computation dependency is generated bottom-up, we prepend the new block instead of\n",
    "            # appending it.\n",
    "            blocks.insert(0, sampled_block)\n",
    "            \n",
    "        return blocks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Block for first layer\n",
      "---------------------\n",
      "Node ID of input nodes in original graph: tensor([ 4,  6,  0,  2,  7,  8,  9, 10,  1,  3])\n",
      "Node ID of output nodes in original graph: tensor([4, 6, 0, 2, 7])\n",
      "Edge connections: tensor([ 0,  2,  7,  0,  2,  2,  4,  6,  7,  8,  9, 10,  0,  4,  6,  1,  3,  0,\n",
      "         4]) tensor([4, 4, 4, 6, 6, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 7, 7])\n",
      "\n",
      "Block for second layer\n",
      "----------------------\n",
      "Node ID of input nodes in original graph: tensor([4, 6, 0, 2, 7])\n",
      "Node ID of output nodes in original graph: tensor([4, 6])\n",
      "Edge connections: tensor([0, 2, 7, 0, 2]) tensor([4, 4, 4, 6, 6])\n"
     ]
    }
   ],
   "source": [
    "block_sampler = FullNeighborBlockSampler(example_graph, 2)\n",
    "sampled_blocks = block_sampler.sample(sampled_node_batch)\n",
    "\n",
    "print('Block for first layer')\n",
    "print('---------------------')\n",
    "print_block_info(sampled_blocks[0])\n",
    "print()\n",
    "print('Block for second layer')\n",
    "print('----------------------')\n",
    "print_block_info(sampled_blocks[1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The message propagation is instead a top-down process, as opposed to computation dependency generation: we start from the input layer, and computes the representations towards the output layer.\n",
    "\n",
    "Now we modify our GraphSAGEModel so it can forward on blocks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SAGENet(nn.Module):\n",
    "    def __init__(self, n_layers, in_feats, out_feats, hidden_feats=None):\n",
    "        super().__init__()\n",
    "        self.convs = nn.ModuleList()\n",
    "        \n",
    "        if hidden_feats is None:\n",
    "            hidden_feats = out_feats\n",
    "        \n",
    "        if n_layers == 1:\n",
    "            self.convs.append(dglnn.SAGEConv(in_feats, out_feats, 'mean'))\n",
    "        else:\n",
    "            self.convs.append(dglnn.SAGEConv(in_feats, hidden_feats, 'mean', activation=F.relu))\n",
    "            for i in range(n_layers - 2):\n",
    "                self.convs.append(dglnn.SAGEConv(hidden_feats, hidden_feats, 'mean', activation=F.relu))\n",
    "            self.convs.append(dglnn.SAGEConv(hidden_feats, out_feats, 'mean'))\n",
    "        \n",
    "    def forward(self, blocks, input_features):\n",
    "        \"\"\"\n",
    "        blocks : List of blocks generated by block sampler.\n",
    "        input_features : Input features of the first block.\n",
    "        \"\"\"\n",
    "        h = input_features\n",
    "        for layer, block in zip(self.convs, blocks):\n",
    "            h = self.propagate(block, h, layer)\n",
    "        return h\n",
    "    \n",
    "    def propagate(self, block, h, layer):\n",
    "        # Because GraphSAGE requires not only the features of the neighbors, but also the features\n",
    "        # of the output nodes themselves on the current layer, we need to copy the output node features\n",
    "        # from the input side to the output side ourselves to make GraphSAGE work correctly.\n",
    "        # The output nodes of a block are guaranteed to appear the first in the input nodes, so we can\n",
    "        # conveniently write like this:\n",
    "        h_dst = h[:block.number_of_dst_nodes()]\n",
    "        h = layer(block, (h, h_dst))\n",
    "        return h"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 2.5556,  0.2520, -0.7534,  0.7471,  2.5092,  3.3257],\n",
      "        [ 1.1765,  0.5756, -1.0913, -0.1211,  0.5537,  1.8392]],\n",
      "       grad_fn=<AddBackward0>)\n"
     ]
    }
   ],
   "source": [
    "sagenet = SAGENet(2, INPUT_FEATURES, OUTPUT_FEATURES)\n",
    "\n",
    "# The input nodes for computing 2-layer GraphSAGE output on the given output nodes can be obtained like this:\n",
    "sampled_input_nodes = sampled_blocks[0].srcdata[dgl.NID]\n",
    "\n",
    "# Get the input features.\n",
    "# In real life we want to copy this to GPU.  But in this hands-on tutorial we don't have GPUs.\n",
    "sampled_input_features = example_graph.ndata['features'][sampled_input_nodes]\n",
    "\n",
    "output_of_sampled_node_batch = sagenet(sampled_blocks, sampled_input_features)\n",
    "print(output_of_sampled_node_batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Neighborhood Sampling\n",
    "\n",
    "We may notice in the above example that 2-hop neighbors actually almost covered the entire graph.  In real world graphs whose node degrees often follow a power-law distribution (i.e. there would exist a few \"hub\" nodes with lots of edges), we indeed often observe that for a small set of output nodes from a multi-layer GNN, the input nodes will still cover a large part of the graph.  The whole purpose of saving GPU memory thus fails again in this setting.\n",
    "\n",
    "Neighborhood sampling offers a solution by *not* taking all neighbors for every node during computation dependency generation.  Instead, we pick a small subset of neighbors and estimate the aggregation of all neighbors from this subset.  This trick often not only reduces memory consumption, but also improves model generalization.\n",
    "\n",
    "DGL provides a function `dgl.sampling.sample_neighbors` for uniform sampling a fixed number of neighbors of each node.  One can also change `dgl.sampling.sample_neighbors` to any kind of existing neighborhood sampling algorithm (including your own)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NeighborSampler(object):\n",
    "    def __init__(self, g, num_fanouts):\n",
    "        \"\"\"\n",
    "        num_fanouts : list of fanouts on each layer.\n",
    "        \"\"\"\n",
    "        self.g = g\n",
    "        self.num_fanouts = num_fanouts\n",
    "        \n",
    "    def sample(self, seeds):\n",
    "        seeds = torch.LongTensor(seeds)\n",
    "        blocks = []\n",
    "        for fanout in reversed(self.num_fanouts):\n",
    "            # We simply switch from in_subgraph to sample_neighbors for neighbor sampling.\n",
    "            sampled_graph = dgl.sampling.sample_neighbors(self.g, seeds, fanout)\n",
    "            \n",
    "            sampled_block = dgl.to_block(sampled_graph, seeds)\n",
    "            seeds = sampled_block.srcdata[dgl.NID]\n",
    "            # Because the computation dependency is generated bottom-up, we prepend the new block instead of\n",
    "            # appending it.\n",
    "            blocks.insert(0, sampled_block)\n",
    "            \n",
    "        return blocks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Block for first layer\n",
      "---------------------\n",
      "Node ID of input nodes in original graph: tensor([ 4,  6,  7,  2,  0,  1,  3, 10])\n",
      "Node ID of output nodes in original graph: tensor([4, 6, 7, 2, 0])\n",
      "Edge connections: tensor([ 0,  7,  0,  2,  0,  4,  1,  3,  2, 10]) tensor([4, 4, 6, 6, 7, 7, 2, 2, 0, 0])\n",
      "\n",
      "Block for second layer\n",
      "----------------------\n",
      "Node ID of input nodes in original graph: tensor([4, 6, 7, 2, 0])\n",
      "Node ID of output nodes in original graph: tensor([4, 6])\n",
      "Edge connections: tensor([7, 2, 0, 2]) tensor([4, 4, 6, 6])\n"
     ]
    }
   ],
   "source": [
    "block_sampler = NeighborSampler(example_graph, [2, 2])\n",
    "sampled_blocks = block_sampler.sample(sampled_node_batch)\n",
    "\n",
    "print('Block for first layer')\n",
    "print('---------------------')\n",
    "print_block_info(sampled_blocks[0])\n",
    "print()\n",
    "print('Block for second layer')\n",
    "print('----------------------')\n",
    "print_block_info(sampled_blocks[1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that each output node now has at most 2 neighbors.\n",
    "\n",
    "Code for message passing on blocks generated with neighborhood sampling does not change at all."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 1.4713,  1.4072,  2.4270,  2.9882,  2.9500, -1.4210],\n",
      "        [ 0.5040,  0.1228,  1.1560,  0.7806,  1.7027, -0.5086]],\n",
      "       grad_fn=<AddBackward0>)\n"
     ]
    }
   ],
   "source": [
    "sagenet = SAGENet(2, INPUT_FEATURES, OUTPUT_FEATURES)\n",
    "\n",
    "# The input nodes for computing 2-layer GraphSAGE output on the given output nodes can be obtained like this:\n",
    "sampled_input_nodes = sampled_blocks[0].srcdata[dgl.NID]\n",
    "\n",
    "# Get the input features.\n",
    "# In real life we want to copy this to GPU.  But in this hands-on tutorial we don't have GPUs.\n",
    "sampled_input_features = example_graph.ndata['features'][sampled_input_nodes]\n",
    "\n",
    "output_of_sampled_node_batch = sagenet(sampled_blocks, sampled_input_features)\n",
    "print(output_of_sampled_node_batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inference with Models Trained with Neighbor Sampling\n",
    "\n",
    "Recall that modules such as Dropout or batch normalization have different formulations in training and inference.  The reason was that we do not wish to introduce any randomness during inference or model deployment.  Similarly, we do not want to sample any of the neighbors during inference; aggregation should be performed on all neighbors without sampling to eliminate randomness.  However, directly using the multi-layer `FullNeighborBlockSampler` would still cost a lot of memory even during inference, due to the large number of input nodes being covered.\n",
    "\n",
    "The solution to this is to compute representations of all nodes on one intermediate layer at a time.  To be more specific, for a multi-layer GraphSAGE model, we first compute the representation of all nodes on the 1st GraphSAGE layer, using a 1-layer `FullNeighborBlockSampler` to take all neighbors into account.  Such representations are computed in minibatches.  After all the representations from the 1st GraphSAGE layer are computed, we start from there and compute the representation of all nodes on the 2nd GraphSAGE layer.  We repeat the process until we go to the last layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def inference_with_sagenet(sagenet, graph, input_features, batch_size):\n",
    "    block_sampler = FullNeighborBlockSampler(graph, 1)\n",
    "    h = input_features\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        # We are computing all representations of one layer at a time.\n",
    "        # The outer loop iterates over GNN layers.\n",
    "        for conv in sagenet.convs:\n",
    "            new_h_list = []\n",
    "            node_ids = torch.arange(graph.number_of_nodes())\n",
    "            # The inner loop iterates over batch of nodes.\n",
    "            for batch_start in range(0, graph.number_of_nodes(), batch_size):\n",
    "                # Sample a block with full neighbors of the current node batch\n",
    "                block = block_sampler.sample(node_ids[batch_start:batch_start+batch_size])[0]\n",
    "                # Get the necessary input node IDs for this node batch on this layer\n",
    "                input_node_ids = block.srcdata[dgl.NID]\n",
    "                # Get the input features\n",
    "                h_input = h[input_node_ids]\n",
    "                # Compute the output of this node batch on this layer\n",
    "                new_h = sagenet.propagate(block, h_input, conv)\n",
    "                new_h_list.append(new_h)\n",
    "            # We finished computing all representations on this layer.  We need to compute the\n",
    "            # representations of next layer.\n",
    "            h = torch.cat(new_h_list)\n",
    "        \n",
    "    return h"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 1.7681e+00, -3.3778e-02,  8.7602e-03, -9.2630e-01, -1.4338e+00,\n",
      "          1.1513e+00],\n",
      "        [-8.6763e-01,  5.2748e-01,  1.5401e-01,  1.1183e+00,  4.7106e-01,\n",
      "          3.1708e-01],\n",
      "        [ 6.2210e-01,  1.4617e-02,  7.6488e-01,  2.8973e-01,  6.6038e-01,\n",
      "          7.6737e-01],\n",
      "        [ 1.9717e+00, -4.4922e-01, -1.0719e-03, -8.4052e-01, -2.3791e-01,\n",
      "          6.2204e-01],\n",
      "        [ 1.2272e+00,  1.1383e+00,  2.0385e+00,  2.7903e+00,  2.5650e+00,\n",
      "         -7.2432e-01],\n",
      "        [ 7.9594e-01,  8.5921e-02,  3.5123e-01,  4.9928e-01,  9.7032e-01,\n",
      "          8.0882e-01],\n",
      "        [ 5.9009e-02,  4.4664e-01,  9.2594e-01,  1.3676e+00,  1.1535e+00,\n",
      "          4.4968e-01],\n",
      "        [ 1.6940e+00,  8.8744e-02,  1.0162e+00,  7.7402e-01,  8.0411e-01,\n",
      "          9.1836e-01],\n",
      "        [-2.1243e-01,  1.0425e+00,  8.5567e-01,  1.0732e+00,  1.9559e-01,\n",
      "          1.4029e+00],\n",
      "        [ 1.7215e-01,  1.8499e+00,  3.0013e+00,  2.5618e+00,  3.0003e+00,\n",
      "         -8.3607e-01],\n",
      "        [ 4.7615e-01,  1.0388e+00,  2.3129e+00,  9.2240e-01,  1.6990e+00,\n",
      "          4.0923e-01],\n",
      "        [ 9.1681e-01, -1.5961e-01, -8.3462e-01, -2.1443e+00, -3.8410e-01,\n",
      "          2.2363e+00]])\n"
     ]
    }
   ],
   "source": [
    "print(inference_with_sagenet(sagenet, example_graph, example_graph.ndata['features'], 2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Putting Together\n",
    "\n",
    "Now let's see how we could apply stochastic training on a node classification task.  We take PubMed dataset as an example.\n",
    "\n",
    "### Load Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished data loading and preprocessing.\n",
      "  NumNodes: 19717\n",
      "  NumEdges: 88651\n",
      "  NumFeats: 500\n",
      "  NumClasses: 3\n",
      "  NumTrainingSamples: 60\n",
      "  NumValidationSamples: 500\n",
      "  NumTestSamples: 1000\n"
     ]
    }
   ],
   "source": [
    "import dgl.data\n",
    "\n",
    "dataset = dgl.data.citation_graph.load_pubmed()\n",
    "\n",
    "# Set features and labels for each node\n",
    "graph = dgl.graph(dataset.graph)\n",
    "graph.ndata['features'] = torch.FloatTensor(dataset.features)\n",
    "graph.ndata['labels'] = torch.LongTensor(dataset.labels)\n",
    "\n",
    "# Find the node IDs in the training, validation, and test set.\n",
    "train_nid = dataset.train_mask.nonzero()[0]\n",
    "val_nid = dataset.val_mask.nonzero()[0]\n",
    "test_nid = dataset.test_mask.nonzero()[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define Neighbor Sampler\n",
    "\n",
    "We can reuse our neighbor sampler code above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "neighbor_sampler = NeighborSampler(graph, [10, 10])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define DataLoader\n",
    "\n",
    "PyTorch generates minibatches with a `DataLoader` object.  We can also use it.\n",
    "\n",
    "Note that to compute the output of a minibatch of nodes, we need a list of blocks described as above.  Therefore, we need to change the `collate_fn` argument which defines how to compose different individual examples into a minibatch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.utils.data\n",
    "\n",
    "BATCH_SIZE = 5\n",
    "\n",
    "train_dataloader = torch.utils.data.DataLoader(\n",
    "    train_nid, batch_size=BATCH_SIZE, collate_fn=neighbor_sampler.sample, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define Model and Optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
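    "HIDDEN_FEATURES = 10\n",
    "# Two layers, matching the two fanouts given to the neighbor sampler;\n",
    "# the input size is the feature dimension and the output size is the number of classes.\n",
    "model = SAGENet(2, dataset.features.shape[1], dataset.num_labels, HIDDEN_FEATURES)\n",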
    "\n",
    "opt = torch.optim.Adam(model.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_accuracy(pred, labels):\n",
    "    # Fraction of examples whose highest-scoring class equals the true label.\n",
    "    return (pred.argmax(1) == labels).float().mean().item()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation acc: 0.3919999897480011 Test acc: 0.41600000858306885\n",
      "Validation acc: 0.3959999978542328 Test acc: 0.4180000126361847\n",
      "Validation acc: 0.40799999237060547 Test acc: 0.42500001192092896\n",
      "Validation acc: 0.4359999895095825 Test acc: 0.4390000104904175\n",
      "Validation acc: 0.4860000014305115 Test acc: 0.4860000014305115\n",
      "Validation acc: 0.6200000047683716 Test acc: 0.6029999852180481\n",
      "Validation acc: 0.6759999990463257 Test acc: 0.6769999861717224\n",
      "Validation acc: 0.6620000004768372 Test acc: 0.6299999952316284\n",
      "Validation acc: 0.6880000233650208 Test acc: 0.6729999780654907\n",
      "Validation acc: 0.7239999771118164 Test acc: 0.7129999995231628\n",
      "Validation acc: 0.7080000042915344 Test acc: 0.7070000171661377\n",
      "Validation acc: 0.7099999785423279 Test acc: 0.7099999785423279\n",
      "Validation acc: 0.6959999799728394 Test acc: 0.6880000233650208\n",
      "Validation acc: 0.6980000138282776 Test acc: 0.6990000009536743\n",
      "Validation acc: 0.7260000109672546 Test acc: 0.7089999914169312\n",
      "Validation acc: 0.6959999799728394 Test acc: 0.6850000023841858\n",
      "Validation acc: 0.7139999866485596 Test acc: 0.7039999961853027\n",
      "Validation acc: 0.7120000123977661 Test acc: 0.6970000267028809\n",
      "Validation acc: 0.7120000123977661 Test acc: 0.6959999799728394\n",
      "Validation acc: 0.7120000123977661 Test acc: 0.7020000219345093\n",
      "Validation acc: 0.7099999785423279 Test acc: 0.6949999928474426\n",
      "Validation acc: 0.7139999866485596 Test acc: 0.7009999752044678\n",
      "Validation acc: 0.7139999866485596 Test acc: 0.699999988079071\n",
      "Validation acc: 0.7139999866485596 Test acc: 0.6990000009536743\n",
      "Validation acc: 0.7139999866485596 Test acc: 0.7020000219345093\n",
      "Validation acc: 0.7279999852180481 Test acc: 0.7089999914169312\n",
      "Validation acc: 0.722000002861023 Test acc: 0.7059999704360962\n",
      "Validation acc: 0.7300000190734863 Test acc: 0.7200000286102295\n",
      "Validation acc: 0.7279999852180481 Test acc: 0.7089999914169312\n",
      "Validation acc: 0.7279999852180481 Test acc: 0.7099999785423279\n",
      "Validation acc: 0.7300000190734863 Test acc: 0.718999981880188\n",
      "Validation acc: 0.7360000014305115 Test acc: 0.7239999771118164\n",
      "Validation acc: 0.7379999756813049 Test acc: 0.7279999852180481\n",
      "Validation acc: 0.7379999756813049 Test acc: 0.7279999852180481\n",
      "Validation acc: 0.7419999837875366 Test acc: 0.7329999804496765\n",
      "Validation acc: 0.7480000257492065 Test acc: 0.734000027179718\n",
      "Validation acc: 0.7519999742507935 Test acc: 0.7350000143051147\n",
      "Validation acc: 0.7459999918937683 Test acc: 0.7379999756813049\n",
      "Validation acc: 0.7480000257492065 Test acc: 0.734000027179718\n",
      "Validation acc: 0.7519999742507935 Test acc: 0.7409999966621399\n",
      "Validation acc: 0.75 Test acc: 0.7400000095367432\n",
      "Validation acc: 0.7519999742507935 Test acc: 0.7429999709129333\n",
      "Validation acc: 0.7559999823570251 Test acc: 0.7429999709129333\n",
      "Validation acc: 0.7680000066757202 Test acc: 0.7490000128746033\n",
      "Validation acc: 0.7580000162124634 Test acc: 0.7459999918937683\n",
      "Validation acc: 0.7680000066757202 Test acc: 0.7490000128746033\n",
      "Validation acc: 0.7639999985694885 Test acc: 0.7490000128746033\n",
      "Validation acc: 0.7639999985694885 Test acc: 0.7480000257492065\n",
      "Validation acc: 0.7680000066757202 Test acc: 0.7490000128746033\n",
      "Validation acc: 0.7620000243186951 Test acc: 0.753000020980835\n"
     ]
    }
   ],
   "source": [
    "NUM_EPOCHS = 50\n",
    "EVAL_BATCH_SIZE = 1000\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    model.train()\n",
    "    for blocks in train_dataloader:\n",
    "        # Input nodes of the first block feed the model; output nodes of the last block carry the labels.\n",
    "        input_nodes = blocks[0].srcdata[dgl.NID]\n",
    "        output_nodes = blocks[-1].dstdata[dgl.NID]\n",
    "        \n",
    "        input_features = graph.ndata['features'][input_nodes]\n",
    "        output_labels = graph.ndata['labels'][output_nodes]\n",
    "        \n",
    "        output_predictions = model(blocks, input_features)\n",
    "        loss = F.cross_entropy(output_predictions, output_labels)\n",
    "        opt.zero_grad()\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "        \n",
    "    # Evaluate on the full graph after each epoch.\n",
    "    model.eval()\n",
    "    all_predictions = inference_with_sagenet(model, graph, graph.ndata['features'], EVAL_BATCH_SIZE)\n",
    "    \n",
    "    val_predictions = all_predictions[val_nid]\n",
    "    val_labels = graph.ndata['labels'][val_nid]\n",
    "    test_predictions = all_predictions[test_nid]\n",
    "    test_labels = graph.ndata['labels'][test_nid]\n",
    "    \n",
    "    print('Validation acc:', compute_accuracy(val_predictions, val_labels),\n",
    "          'Test acc:', compute_accuracy(test_predictions, test_labels))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
